# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Error classes for the GenAI SDK."""

from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import httpx
import json
from . import _common


if TYPE_CHECKING:
  from .replay_api_client import ReplayResponse
  import aiohttp


class APIError(Exception):
  """General errors raised by the GenAI API.

  Attributes:
    code: The HTTP status code passed to the constructor or, when that value
      is falsy, the ``code`` field read from the error payload.
    response: The response object the error was built from, if any.
    details: The error payload (a single-element list is unwrapped first).
    status: The ``status`` field of the payload, looked up at the top level
      and then under a nested ``error`` object.
    message: The ``message`` field of the payload, looked up the same way.
  """
  code: int
  response: Optional[
      Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse']
  ]

  status: Optional[str] = None
  message: Optional[str] = None

  def __init__(
      self,
      code: int,
      response_json: Any,
      response: Optional[
          Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse']
      ] = None,
  ):
    """Initializes the error from a status code and an error payload.

    Args:
      code: The HTTP status code. If falsy, the payload's ``code`` is used.
      response_json: The decoded error body. Payloads without a ``get``
        method (e.g. strings or None) are tolerated; any field that cannot be
        found is reported as None.
      response: The original response object, kept for callers to inspect.
    """
    # Unwrap a single-element list so the field lookups see the error object.
    if isinstance(response_json, list) and len(response_json) == 1:
      response_json = response_json[0]

    self.response = response
    self.details = response_json
    self.message = self._get_message(response_json)
    self.status = self._get_status(response_json)
    self.code = code if code else self._get_code(response_json)

    super().__init__(f'{self.code} {self.status}. {self.details}')

  def __reduce__(
      self,
  ) -> tuple[
      Callable[..., 'APIError'], tuple[dict[str, Any], type['APIError']]
  ]:
    """Returns a tuple that can be used to reconstruct the error for pickling.

    The concrete class is recorded alongside the instance state so that
    subclasses such as ClientError and ServerError unpickle as themselves
    rather than as a plain APIError.
    """
    state = self.__dict__.copy()
    return (type(self)._rebuild, (state, type(self)))

  @staticmethod
  def _rebuild(
      state: dict[str, Any], error_cls: Optional[type['APIError']] = None
  ) -> 'APIError':
    """Rebuilds the error from the state.

    Args:
      state: The instance ``__dict__`` captured by ``__reduce__``.
      error_cls: The class to instantiate. Defaults to APIError so that
        pickles written before the class was recorded still load.

    Returns:
      A new error instance whose ``args`` match the original message.
    """
    target_cls = error_cls if error_cls is not None else APIError
    obj = target_cls.__new__(target_cls)
    obj.__dict__.update(state)
    # __init__ is bypassed, so set Exception.args explicitly for str(obj).
    Exception.__init__(obj, f'{obj.code} {obj.status}. {obj.details}')
    return obj

  @staticmethod
  def _get_field(response_json: Any, key: str) -> Any:
    """Reads ``key`` from the payload, falling back to ``payload['error']``.

    A top-level key wins, even if its value is None. Payloads, or nested
    ``error`` entries, that have no ``get`` method (e.g. strings, lists,
    None) yield None instead of raising AttributeError. A malformed error body
    therefore still produces an APIError.
    """
    if not hasattr(response_json, 'get'):
      return None
    nested = response_json.get('error')
    nested_value = nested.get(key) if hasattr(nested, 'get') else None
    return response_json.get(key, nested_value)

  def _get_status(self, response_json: Any) -> Any:
    return self._get_field(response_json, 'status')

  def _get_message(self, response_json: Any) -> Any:
    return self._get_field(response_json, 'message')

  def _get_code(self, response_json: Any) -> Any:
    return self._get_field(response_json, 'code')

  def _to_replay_record(self) -> _common.StringDict:
    """Returns a dictionary representation of the error for replay recording.

    details is not included since it may expose internal information in the
    replay file.
    """
    return {
        'error': {
            'code': self.code,
            'message': self.message,
            'status': self.status,
        }
    }

  @classmethod
  def raise_for_response(
      cls, response: Union['ReplayResponse', httpx.Response]
  ) -> None:
    """Raises an error with detailed error message if the response has an error status.

    Only a status code of exactly 200 is treated as success.

    Args:
      response: A synchronous httpx response, or a replay response exposing
        ``status_code`` and ``body_segments``.

    Raises:
      ClientError: If the status code is in the 4xx range.
      ServerError: If the status code is in the 5xx range.
      APIError: For other non-200 status codes.
    """
    if response.status_code == 200:
      return

    if isinstance(response, httpx.Response):
      try:
        response.read()
        response_json = response.json()
      except json.decoder.JSONDecodeError:
        # Non-JSON error body: surface the raw text and the reason phrase.
        response_json = {
            'message': response.text,
            'status': response.reason_phrase,
        }
    else:
      response_json = response.body_segments[0].get('error', {})

    cls.raise_error(response.status_code, response_json, response)

  @classmethod
  def raise_error(
      cls,
      status_code: int,
      response_json: Any,
      response: Optional[
          Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse']
      ],
  ) -> None:
    """Raises an appropriate APIError subclass based on the status code.

    Args:
      status_code: The HTTP status code of the response.
      response_json: The JSON body of the response, or a dict containing error
        details.
      response: The original response object.

    Raises:
      ClientError: If the status code is in the 4xx range.
      ServerError: If the status code is in the 5xx range.
      APIError: For other error status codes.
    """
    if 400 <= status_code < 500:
      raise ClientError(status_code, response_json, response)
    elif 500 <= status_code < 600:
      raise ServerError(status_code, response_json, response)
    else:
      raise cls(status_code, response_json, response)

  @classmethod
  async def raise_for_async_response(
      cls,
      response: Union[
          'ReplayResponse', httpx.Response, 'aiohttp.ClientResponse'
      ],
  ) -> None:
    """Raises an error with detailed error message if the response has an error status.

    Supports httpx responses, replay-style responses (anything exposing both
    ``body_segments`` and ``status_code``), and ``aiohttp.ClientResponse``.
    Only a status code of exactly 200 is treated as success.

    Args:
      response: The response to inspect.

    Raises:
      ValueError: If the response type is not supported (or aiohttp is needed
        to identify it but is not installed).
      ClientError: If the status code is in the 4xx range.
      ServerError: If the status code is in the 5xx range.
      APIError: For other non-200 status codes.
    """
    if isinstance(response, httpx.Response):
      if response.status_code == 200:
        return
      try:
        await response.aread()
        response_json = response.json()
      except json.decoder.JSONDecodeError:
        response_json = {
            'message': response.text,
            'status': response.reason_phrase,
        }
      status_code = response.status_code
    elif hasattr(response, 'body_segments') and hasattr(
        response, 'status_code'
    ):
      if response.status_code == 200:
        return
      response_json = response.body_segments[0].get('error', {})
      status_code = response.status_code
    else:
      # aiohttp is optional; import it only when the response may be one.
      try:
        import aiohttp  # pylint: disable=g-import-not-at-top
      except ImportError:
        raise ValueError(f'Unsupported response type: {type(response)}')
      if not isinstance(response, aiohttp.ClientResponse):
        raise ValueError(f'Unsupported response type: {type(response)}')
      if response.status == 200:
        return
      try:
        response_json = await response.json()
      except (
          aiohttp.client_exceptions.ContentTypeError,
          json.decoder.JSONDecodeError,
      ):
        # A non-JSON content type raises ContentTypeError. A JSON content type
        # with a malformed body surfaces the JSONDecodeError from json.loads.
        # Both fall back to the raw text and the reason phrase.
        response_json = {
            'message': await response.text(),
            'status': response.reason,
        }
      status_code = response.status

    await cls.raise_error_async(status_code, response_json, response)

  @classmethod
  async def raise_error_async(
      cls, status_code: int, response_json: Any, response: Optional[
          Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse']
      ]
  ) -> None:
    """Raises an appropriate APIError subclass based on the status code.

    Async-compatible entry point; the mapping is shared with ``raise_error``.

    Args:
      status_code: The HTTP status code of the response.
      response_json: The JSON body of the response, or a dict containing error
        details.
      response: The original response object.

    Raises:
      ClientError: If the status code is in the 4xx range.
      ServerError: If the status code is in the 5xx range.
      APIError: For other error status codes.
    """
    cls.raise_error(status_code, response_json, response)


class ClientError(APIError):
  """Client error raised by the GenAI API.

  Used for responses whose HTTP status code falls in the 4xx range.
  """


class ServerError(APIError):
  """Server error raised by the GenAI API.

  Used for responses whose HTTP status code falls in the 5xx range.
  """


class UnknownFunctionCallArgumentError(ValueError):
  """Signals a function call argument that cannot be converted.

  Raised when the argument does not fit the parameter's type annotation.
  """


class UnsupportedFunctionError(ValueError):
  """Signals that the given function is not supported."""


class FunctionInvocationError(ValueError):
  """Signals that a function could not be invoked with the supplied arguments."""


class UnknownApiResponseError(ValueError):
  """Signals an API response that could not be parsed as JSON."""

# Re-export _common.ExperimentalWarning so it is also reachable from this module.
ExperimentalWarning = _common.ExperimentalWarning
